"""Stochastic variants of lookup-table-based strategies, trained with particle
swarm algorithms.

For the original see:
 https://gist.github.com/GDKO/60c3d0fd423598f3c4e4
"""

from typing import Any, Union

from axelrod.action import Action, actions_to_str, str_to_actions
from axelrod.evolvable_player import (
    EvolvablePlayer,
    InsufficientParametersError,
    crossover_dictionaries,
)
from axelrod.load_data_ import load_pso_tables
from axelrod.player import Player

from .lookerup import (
    EvolvableLookerUp,
    LookerUp,
    LookupTable,
    Plays,
    create_lookup_table_keys,
    get_last_n_plays,
)

# Short aliases for the two possible moves.
C = Action.C
D = Action.D

# Lookup-table data read from "pso_gambler.csv" once, at import time.
tables = load_pso_tables("pso_gambler.csv", directory="data")

class LookerUp(Player):
    """
    This strategy uses a LookupTable to decide its next action. If there is not
    enough history to use the table, it calls from a list of
    self.initial_actions.

    if self_depth=2, op_depth=3, op_openings_depth=5, LookerUp finds the last 2
    plays of self, the last 3 plays of opponent and the opening 5 plays of
    opponent. It then looks those up on the LookupTable and returns the
    appropriate action. If 5 rounds have not been played (the minimum required
    for op_openings_depth), it calls from self.initial_actions.

    LookerUp can be instantiated with a dictionary. The dictionary uses
    tuple(tuple, tuple, tuple) or Plays as keys. For example:

    - self_plays: depth=2
    - op_plays: depth=1
    - op_openings: depth=0::

        {Plays((C, C), (C,), ()): C,
         Plays((C, C), (D,), ()): D,
         Plays((C, D), (C,), ()): D,  <- example below
         Plays((C, D), (D,), ()): D,
         Plays((D, C), (C,), ()): C,
         Plays((D, C), (D,), ()): D,
         Plays((D, D), (C,), ()): C,
         Plays((D, D), (D,), ()): D}

    From the above table, if the player last played C, D and the opponent last
    played C (here the initial opponent play is ignored) then this round,
    the player would play D.

    The dictionary must contain all possible permutations of C's and D's.

    LookerUp can also be instantiated with `pattern=str/tuple` of actions, and::

        parameters=Plays(
            self_plays=player_depth: int,
            op_plays=op_depth: int,
            op_openings=op_openings_depth: int)

    It will create keys of len=2 ** (sum(parameters)) and map the pattern to
    the keys.

    initial_actions is a tuple such as (C, C, D). A table needs initial actions
    equal to max(self_plays depth, opponent_plays depth, opponent_initial_plays
    depth). If provided initial_actions is too long, the extra will be ignored.
    If provided initial_actions is too short, the shortfall will be made up
    with C's.

    Some well-known strategies can be expressed as special cases; for example
    Cooperator is given by the dict (All history is ignored and always play C)::

        {Plays((), (), ()) : C}


    Tit-For-Tat is given by (The only history that is important is the
    opponent's last play.)::

       {Plays((), (D,), ()): D,
        Plays((), (C,), ()): C}


    LookerUp's LookupTable defaults to Tit-For-Tat.  The initial_actions
    defaults to playing C.

    Names:

    - Lookerup: Original name by Martin Jones
    """

    name = "LookerUp"
    classifier = {
        "memory_depth": float("inf"),
        "stochastic": False,
        "long_run_time": False,
        "inspects_source": False,
        "manipulates_source": False,
        "manipulates_state": False,
    }

    # Fallback table (Tit-For-Tat) used when neither a lookup_dict nor a
    # pattern/parameters pair is supplied.
    default_tft_lookup_table = {
        Plays(self_plays=(), op_plays=(D,), op_openings=()): D,
        Plays(self_plays=(), op_plays=(C,), op_openings=()): C,
    }

    def __init__(
        self,
        lookup_dict: dict = None,
        initial_actions: tuple = None,
        pattern: Any = None,  # pattern is str or tuple of Action's.
        parameters: Plays = None,
    ) -> None:
        """
        Build the lookup table and the opening moves.

        :param lookup_dict: mapping of Plays (or equivalent tuples) to the
            response; takes precedence over pattern/parameters when truthy.
        :param initial_actions: moves played before the table can be used.
        :param pattern: str or sequence of responses, used with `parameters`.
        :param parameters: (self depth, opponent depth, opponent openings
            depth) describing the keys that `pattern` is mapped onto.
        """
        Player.__init__(self)
        self.parameters = parameters
        self.pattern = pattern
        self._lookup = self._get_lookup_table(lookup_dict, pattern, parameters)
        self._set_memory_depth()
        self.initial_actions = self._get_initial_actions(initial_actions)
        self._initial_actions_pool = list(self.initial_actions)

    @classmethod
    def _get_lookup_table(
        cls, lookup_dict: dict, pattern: Any, parameters: tuple
    ) -> LookupTable:
        """
        Return the LookupTable for the given constructor arguments.

        Priority: a truthy `lookup_dict`; otherwise `pattern` together with
        `parameters`; otherwise the class's default Tit-For-Tat table.
        """
        if lookup_dict:
            return LookupTable(lookup_dict=lookup_dict)
        if pattern is not None and parameters is not None:
            if isinstance(pattern, str):
                pattern = str_to_actions(pattern)
            self_depth, op_depth, op_openings_depth = parameters
            return LookupTable.from_pattern(
                pattern, self_depth, op_depth, op_openings_depth
            )
        # The default table is a class attribute, so it has to be reached
        # through `cls`; a bare name is not visible from inside a method.
        return LookupTable(cls.default_tft_lookup_table)

    def _set_memory_depth(self):
        """
        Record the memory depth in the classifier.

        A table that looks at the opponent's opening plays keeps an infinite
        memory depth; otherwise the depth is the table's own depth.
        """
        # NOTE(review): this writes into self.classifier; presumably
        # Player.__init__ gives each instance its own copy -- confirm.
        if self._lookup.op_openings_depth == 0:
            self.classifier["memory_depth"] = self._lookup.table_depth
        else:
            self.classifier["memory_depth"] = float("inf")

    def _get_initial_actions(self, initial_actions: tuple) -> tuple:
        """
        Return exactly `table_depth` opening moves as a tuple.

        Missing or empty input yields all C's, a short input is padded with
        C's and a long input is truncated.
        """
        table_depth = self._lookup.table_depth
        if not initial_actions:
            return tuple([C] * table_depth)
        # Accept any sequence (e.g. a list) while always returning a tuple.
        initial_actions = tuple(initial_actions)
        shortfall = table_depth - len(initial_actions)
        if shortfall > 0:
            return initial_actions + (C,) * shortfall
        return initial_actions[:table_depth]

    def strategy(self, opponent: Player) -> Union[Action, float]:
        """
        Return the table's response for the current history.

        While fewer rounds have been played than there are initial actions,
        the next initial action is returned instead. The result is whatever
        the table stores for the key; subclasses such as Gambler handle the
        case where that value is not an Action.
        """
        turn_index = len(opponent.history)
        if turn_index < len(self._initial_actions_pool):
            return self._initial_actions_pool[turn_index]

        player_last_n_plays = get_last_n_plays(
            player=self, depth=self._lookup.player_depth
        )
        opponent_last_n_plays = get_last_n_plays(
            player=opponent, depth=self._lookup.op_depth
        )
        opponent_initial_plays = tuple(
            opponent.history[: self._lookup.op_openings_depth]
        )

        return self._lookup.get(
            player_last_n_plays, opponent_last_n_plays, opponent_initial_plays
        )

    @property
    def lookup_dict(self):
        """The dictionary held by the underlying LookupTable."""
        return self._lookup.dictionary

    def lookup_table_display(
        self, sort_by: tuple = ("op_openings", "self_plays", "op_plays")
    ) -> str:
        """
        Returns a string for printing lookup_table info in specified order.

        :param sort_by: only_elements='self_plays', 'op_plays', 'op_openings'
        """
        return self._lookup.display(sort_by=sort_by)

class Gambler(LookerUp):
    """
    A stochastic variant of LookerUp: when the looked-up entry is not an
    Action, the move is drawn at random using that entry.

    Names:

    - Gambler: Original name by Georgios Koutsovoulos
    """

    name = "Gambler"
    classifier = {
        "memory_depth": float("inf"),
        "stochastic": True,
        "long_run_time": False,
        "inspects_source": False,
        "manipulates_source": False,
        "manipulates_state": False,
    }

    def strategy(self, opponent: Player) -> Action:
        """Return the looked-up Action, or a random draw based on the entry."""
        looked_up = super().strategy(opponent)
        if not isinstance(looked_up, Action):
            # Presumably a probability handed to the random generator --
            # confirm against random_choice.
            return self._random.random_choice(looked_up)
        return looked_up

class EvolvableLookerUp(LookerUp, EvolvablePlayer):
    """
    A LookerUp whose lookup table and initial actions can be mutated and
    crossed over with another player of the same class.
    """

    name = "EvolvableLookerUp"

    def __init__(
        self,
        lookup_dict: dict = None,
        initial_actions: tuple = None,
        pattern: Any = None,  # pattern is str or tuple of Action's.
        parameters: Plays = None,
        mutation_probability: float = None,
        seed: int = None,
    ) -> None:
        """
        Normalise the arguments, then initialise the LookerUp part.

        :param lookup_dict: see LookerUp.
        :param initial_actions: see LookerUp.
        :param pattern: see LookerUp.
        :param parameters: see LookerUp; with only this given, a random
            pattern (and, if needed, random initial actions) is generated.
        :param mutation_probability: per-entry mutation probability; defaults
            to 2 / (number of table keys).
        :param seed: forwarded to EvolvablePlayer.__init__.
        :raises InsufficientParametersError: if the arguments do not determine
            a lookup table.
        """
        # Runs first because _normalize_parameters may draw from self._random.
        EvolvablePlayer.__init__(self, seed=seed)
        (
            lookup_dict,
            initial_actions,
            pattern,
            parameters,
            mutation_probability,
        ) = self._normalize_parameters(
            lookup_dict,
            initial_actions,
            pattern,
            parameters,
            mutation_probability,
        )
        LookerUp.__init__(
            self,
            lookup_dict=lookup_dict,
            initial_actions=initial_actions,
            pattern=pattern,
            parameters=parameters,
        )
        self.mutation_probability = mutation_probability
        self.overwrite_init_kwargs(
            lookup_dict=lookup_dict,
            initial_actions=initial_actions,
            pattern=pattern,
            parameters=parameters,
            mutation_probability=mutation_probability,
        )

    def _normalize_parameters(
        self,
        lookup_dict=None,
        initial_actions=None,
        pattern=None,
        parameters=None,
        mutation_probability=None,
    ):
        """
        Fill in whichever of the constructor arguments can be derived.

        Three input combinations are accepted, checked in this order:
        lookup_dict + initial_actions; pattern + parameters +
        initial_actions; parameters alone (random pattern, and random
        initial actions if none were given).

        :returns: (lookup_dict, initial_actions, pattern, parameters,
            mutation_probability) with pattern as a tuple.
        :raises InsufficientParametersError: if none of the combinations
            applies.
        """
        if lookup_dict and initial_actions:
            # Compute the associated pattern and parameters
            # Map the table keys to namedTuple Plays
            lookup_table = self._get_lookup_table(
                lookup_dict, pattern, parameters
            )
            lookup_dict = lookup_table.dictionary
            parameters = (
                lookup_table.player_depth,
                lookup_table.op_depth,
                lookup_table.op_openings_depth,
            )
            pattern = tuple(v for _, v in sorted(lookup_dict.items()))
        elif pattern and parameters and initial_actions:
            # Compute the associated lookup table
            lookup_table = self._get_lookup_table(
                lookup_dict, pattern, parameters
            )
            lookup_dict = lookup_table.dictionary
        elif parameters:
            # Generate a random pattern and (maybe) initial actions
            plays, op_plays, op_start_plays = parameters
            pattern, lookup_table = self.random_params(
                plays, op_plays, op_start_plays
            )
            lookup_dict = lookup_table.dictionary
            if not initial_actions:
                num_actions = max([plays, op_plays, op_start_plays])
                initial_actions = tuple(
                    [self._random.choice((C, D)) for _ in range(num_actions)]
                )
        else:
            raise InsufficientParametersError(
                "Insufficient Parameters to instantiate EvolvableLookerUp"
            )
        # Normalize pattern
        if isinstance(pattern, str):
            pattern = str_to_actions(pattern)
        pattern = tuple(pattern)
        if mutation_probability is None:
            plays, op_plays, op_start_plays = parameters
            keys = create_lookup_table_keys(plays, op_plays, op_start_plays)
            mutation_probability = 2.0 / len(keys)
        return (
            lookup_dict,
            initial_actions,
            pattern,
            parameters,
            mutation_probability,
        )

    def random_value(self):
        """Return a random table entry: C or D."""
        return self._random.choice((C, D))

    def random_params(self, plays, op_plays, op_start_plays):
        """
        Build a random table for the given depths.

        :returns: (pattern, LookupTable), where pattern is the list of random
            values in key order.
        """
        keys = create_lookup_table_keys(plays, op_plays, op_start_plays)
        # To get a pattern, we just randomly pick a value for each key
        pattern = [self.random_value() for _ in keys]
        table = dict(zip(keys, pattern))
        return pattern, LookupTable(table)

    @classmethod
    def mutate_value(cls, value):
        """Return the mutated form of a table entry (the flipped action)."""
        return value.flip()

    def mutate_table(self, table, mutation_probability):
        """
        Mutate each entry of `table` with probability `mutation_probability`.

        The table is modified in place and also returned.
        """
        # One random draw per entry, requested in a single call.
        randoms = self._random.random(len(table))
        # Only existing keys are reassigned, so iterating while assigning is
        # safe (the dict never changes size).
        for draw, (history, move) in zip(randoms, table.items()):
            if draw < mutation_probability:
                table[history] = self.mutate_value(move)
        return table

    def mutate(self):
        """
        Return a new player with a mutated table and initial actions.

        NOTE(review): mutate_table works in place on self.lookup_dict, so this
        player's own table is altered too -- confirm that is intended.
        """
        lookup_dict = self.mutate_table(
            self.lookup_dict, self.mutation_probability
        )
        # Add in starting moves; one random draw per action, in order.
        initial_actions = tuple(
            action.flip()
            if self._random.random() < self.mutation_probability
            else action
            for action in self.initial_actions
        )
        return self.create_new(
            lookup_dict=lookup_dict,
            initial_actions=initial_actions,
        )

    def crossover(self, other):
        """
        Return a new player whose table combines this one's and `other`'s.

        :raises TypeError: if `other` is not of exactly the same class.
        """
        if other.__class__ != self.__class__:
            raise TypeError(
                "Crossover must be between the same player classes."
            )
        lookup_dict = crossover_dictionaries(
            self.lookup_dict, other.lookup_dict, self._random
        )
        return self.create_new(lookup_dict=lookup_dict)

class EvolvableGambler(Gambler, EvolvableLookerUp):
    """Evolvable Gambler: table entries are floats rather than actions."""

    name = "EvolvableGambler"

    def __init__(
        self,
        lookup_dict: dict = None,
        initial_actions: tuple = None,
        pattern: Any = None,  # pattern is str or tuple of Actions.
        parameters: Plays = None,
        mutation_probability: float = None,
        seed: int = None,
    ) -> None:
        """Initialise via EvolvableLookerUp, then re-initialise as a Gambler."""
        # Step 1: let EvolvableLookerUp normalise all of the arguments.
        EvolvableLookerUp.__init__(
            self,
            lookup_dict=lookup_dict,
            initial_actions=initial_actions,
            pattern=pattern,
            parameters=parameters,
            mutation_probability=mutation_probability,
            seed=seed,
        )
        # Step 2: keep the pattern as a list instead of a tuple.
        self.pattern = list(self.pattern)
        # Step 3: run the Gambler initialiser on the normalised values.
        Gambler.__init__(
            self,
            lookup_dict=self.lookup_dict,
            initial_actions=self.initial_actions,
            pattern=self.pattern,
            parameters=self.parameters,
        )
        # Step 4: record the final values as the init kwargs.
        self.overwrite_init_kwargs(
            lookup_dict=self.lookup_dict,
            initial_actions=self.initial_actions,
            pattern=self.pattern,
            parameters=self.parameters,
            mutation_probability=self.mutation_probability,
        )

    # mutate and crossover come from EvolvableLookerUp; only the per-value
    # hooks below differ.

    def random_value(self) -> float:
        """Return a random table entry drawn from the player's generator."""
        return self._random.random()

    def mutate_value(self, value: float) -> float:
        """Shift `value` by a random amount in [-0.25, 0.25], clamped to [0, 1]."""
        shifted = value + self._random.uniform(-1, 1) / 4
        if shifted < 0:
            return 0
        if shifted > 1:
            return 1
        return shifted

    def receive_vector(self, vector):
        """Store `vector` as the pattern and rebuild the lookup table from it."""
        self.pattern = vector
        own_depth, opp_depth, opp_openings_depth = self.parameters
        self._lookup = LookupTable.from_pattern(
            self.pattern, own_depth, opp_depth, opp_openings_depth
        )

    def create_vector_bounds(self):
        """Return (lower, upper) bound lists, 0.0 and 1.0 per pattern entry."""
        n_entries = len(self.pattern)
        return [0.0] * n_entries, [1.0] * n_entries